{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 创建 TVM 的 NDArray 的子类\n",
    "\n",
    "::::{dropdown} C++ 源码：\n",
    "```{literalinclude} src/testing/NDSubClass.cc\n",
    ":language: C++\n",
    "```\n",
    "::::"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "编译："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "g++ -std=c++17 -O2 -fPIC -I/media/pc/data/lxw/ai/tvm/include -I/media/pc/data/lxw/ai/tvm/3rdparty/dmlc-core/include -I/media/pc/data/lxw/ai/tvm/3rdparty/dlpack/include -Iinclude -DDMLC_USE_LOGGING_LIBRARY=\\<tvm/runtime/logging.h\\> -shared -o outputs/libs/libtvm_NDSubClass.so src/testing/NDSubClass.cc -ldl -pthread -L/media/pc/data/lxw/ai/tvm/build\n"
     ]
    }
   ],
   "source": [
    "# Build the shared library from src/testing/NDSubClass.cc (see the recorded g++ command in the output).\n",
    "!make outputs/libs/libtvm_NDSubClass.so"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `tvm_ext.ivec_create`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tvm\n",
    "from tvm_book.tvm_ext.libinfo import load_lib\n",
    "\n",
    "# Load the extension built above; `load_lib` presumably returns the library\n",
    "# handle and its resolved name -- verify in tvm_book.\n",
    "_LIB, _LIB_NAME = load_lib(\n",
    "    name=\"libtvm_NDSubClass.so\",\n",
    "    search_path=[\"outputs/libs\"],\n",
    ")\n",
    "# NOTE(review): presumably binds every `tvm_ext.*` global function as an attribute\n",
    "# of the current module (`__main__` inside a notebook) -- confirm against TVM.\n",
    "tvm._ffi._init_api(\"tvm_ext\", __name__)\n",
    "\n",
    "# Explicit handles to the global functions used by the IntVec example below.\n",
    "ivec_create, ivec_get = map(\n",
    "    tvm.get_global_func,\n",
    "    (\"tvm_ext.ivec_create\", \"tvm_ext.ivec_get\"),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "要使用此扩展，外部库需要完成以下步骤：\n",
    "\n",
    "1. 继承 TVM 的 NDArray 和 NDArray 容器；\n",
    "2. 遵循新的对象协议，将新的 NDArray 定义为引用类；\n",
    "3. 在 Python 前端继承 `tvm.nd.NDArrayBase`（见下文 `NDSubClass`），并使用 `tvm.register_object` 注册类型。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tvm.register_object(\"tvm_ext.IntVector\")\n",
    "class IntVec(tvm.Object):\n",
    "    \"\"\"Python wrapper registered with ``tvm.register_object`` as ``tvm_ext.IntVector``.\n",
    "\n",
    "    Indexing is forwarded to the ``ivec_get`` global function.\n",
    "    \"\"\"\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return the element at position ``idx`` by calling ``ivec_get``.\"\"\"\n",
    "        element = ivec_get(self, idx)\n",
    "        return element\n",
    "\n",
    "    @property\n",
    "    def _tvm_handle(self):\n",
    "        \"\"\"Raw ``value`` of this object's ``handle``.\"\"\"\n",
    "        return self.handle.value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def ivec_cb(v2):\n",
    "    \"\"\"Callback that checks its argument is an ``IntVec`` whose third element is 3.\"\"\"\n",
    "    assert isinstance(v2, IntVec)\n",
    "    assert v2[2] == 3\n",
    "\n",
    "\n",
    "ivec = ivec_create(1, 2, 3)\n",
    "assert isinstance(ivec, IntVec)\n",
    "for position, expected in enumerate((1, 2)):\n",
    "    assert ivec[position] == expected\n",
    "\n",
    "# Convert the Python callback with tvm.runtime.convert, then invoke the result\n",
    "# on the vector (presumably routing the call through TVM's FFI -- confirm).\n",
    "packed_cb = tvm.runtime.convert(ivec_cb)\n",
    "packed_cb(ivec)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `tvm_ext.NDSubClass`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Look up the three `tvm_ext.nd_*` global functions used by NDSubClass below.\n",
    "nd_create, nd_add_two, nd_get_additional_info = map(\n",
    "    tvm.get_global_func,\n",
    "    (\n",
    "        \"tvm_ext.nd_create\",\n",
    "        \"tvm_ext.nd_add_two\",\n",
    "        \"tvm_ext.nd_get_additional_info\",\n",
    "    ),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tvm.register_object(\"tvm_ext.NDSubClass\")\n",
    "class NDSubClass(tvm.nd.NDArrayBase):\n",
    "    \"\"\"Example subclass built on TVM's NDArray infrastructure.\n",
    "\n",
    "    Every operation delegates to a ``tvm_ext.nd_*`` global function, so an\n",
    "    external library can reuse TVM's FFI as-is.\n",
    "    \"\"\"\n",
    "\n",
    "    def __add__(self, other):\n",
    "        \"\"\"Return the result of ``nd_add_two(self, other)``.\"\"\"\n",
    "        return nd_add_two(self, other)\n",
    "\n",
    "    @property\n",
    "    def additional_info(self):\n",
    "        \"\"\"Value reported by ``nd_get_additional_info`` for this object.\"\"\"\n",
    "        return nd_get_additional_info(self)\n",
    "\n",
    "    @staticmethod\n",
    "    def create(additional_info):\n",
    "        \"\"\"Create an object through ``nd_create`` with the given ``additional_info``.\"\"\"\n",
    "        return nd_create(additional_info)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = NDSubClass.create(additional_info=3)\n",
    "b = NDSubClass.create(additional_info=5)\n",
    "assert isinstance(a, NDSubClass)\n",
    "\n",
    "# Each sum is expected to carry the sum of its operands' additional_info.\n",
    "c, d, e = a + b, a + a, b + b\n",
    "\n",
    "for array, expected_info in ((a, 3), (b, 5), (c, 8), (d, 6), (e, 10)):\n",
    "    assert array.additional_info == expected_info"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
